import torch
from model.graph_enc import node_mask_hard

if __name__ == "__main__":
    # Toy input: one sample holding the integers 1..9 laid out as a 3x3 grid,
    # i.e. an int64 tensor of shape (1, 3, 3).
    batch = torch.arange(1, 10).reshape(1, 3, 3)

    # All-zero float tensor of the same (1, 3, 3) shape with a single
    # symmetric pair of entries, (0, 1) and (1, 0), set to 1.
    # NOTE(review): presumably a per-sample adjacency matrix with one
    # undirected edge between nodes 0 and 1 -- confirm against the consumer.
    adjs = torch.zeros(1, 3, 3)
    for row, col in ((0, 1), (1, 0)):
        adjs[0, row, col] = 1

    # Dump the inputs for visual inspection: values, shape, then the mask.
    for shown in (batch, batch.size(), adjs):
        print(shown)
